load("//bazel:api.bzl", "mojo_filecheck_test", "mojo_test")

# Branches for select(): mark a target incompatible when //:apple_gpu matches,
# and add no constraint otherwise.
_APPLE_GPU_INCOMPATIBLE = {
    "//:apple_gpu": ["@platforms//:incompatible"],
    "//conditions:default": [],
}

# Extra `target_compatible_with` entries keyed by test source file name. The
# mojo_test comprehension appends the entry for its source (or [] when the
# source is not listed) to ["//:has_gpu"].
_EXTRA_MOJO_TEST_CONSTRAINTS = {
    "test_dump_sass.mojo": ["//:nvidia_gpu"],
    "test_function_attributes.mojo": select(_APPLE_GPU_INCOMPATIBLE),
    "test_device_stream.mojo": select(_APPLE_GPU_INCOMPATIBLE),  # FIXME: GEX-2554
    "test_device_event.mojo": select(_APPLE_GPU_INCOMPATIBLE),  # FIXME: GEX-2555
    "test_external_shared_mem.mojo": select(_APPLE_GPU_INCOMPATIBLE),  # FIXME: MOCO-2397
    "test_device_context_sub_buffer.mojo": select(_APPLE_GPU_INCOMPATIBLE),  # FIXME: GEX-2562
}

# Dependencies shared by every GPU test target declared below. Keeping them in
# one place means a new dependency only has to be added once.
_GPU_TEST_DEPS = [
    "//max:internal_utils",
    "//max:kv_cache",
    "//max:linalg",
    "//max:nn",
    "//max:quantization",
    "@mojo//:stdlib",
]

# Sources that get a dedicated mojo_filecheck_test target (declared with
# expect_crash = True) instead of the generic mojo_test. They are excluded from
# the glob below; a new crash test must be added both here and as an explicit
# mojo_filecheck_test target, otherwise the glob would also create a
# conflicting "<src>.test" mojo_test target for it.
_FILECHECK_TEST_SRCS = [
    "test_function_error.mojo",
    "test_function_error_sync_mode.mojo",
]

# One mojo_test target per .mojo file in this package (recursively), named
# "<src>.test". Every target requires //:has_gpu, plus any per-source extras
# from _EXTRA_MOJO_TEST_CONSTRAINTS.
[
    mojo_test(
        name = src + ".test",
        size = "large",
        srcs = [src],
        tags = ["gpu"],
        target_compatible_with = ["//:has_gpu"] + _EXTRA_MOJO_TEST_CONSTRAINTS.get(src, []),
        deps = _GPU_TEST_DEPS,
    )
    for src in glob(
        ["**/*.mojo"],
        exclude = _FILECHECK_TEST_SRCS,
    )
]

# test_function_error.mojo is expected to crash (expect_crash = True); restricted
# to //:nvidia_gpu platforms.
mojo_filecheck_test(
    name = "test_function_error.mojo.test",
    srcs = ["test_function_error.mojo"],
    expect_crash = True,
    tags = ["gpu"],
    target_compatible_with = ["//:nvidia_gpu"],
    deps = _GPU_TEST_DEPS,
)

# Same kind of crash test, but run with MODULAR_DEVICE_CONTEXT_SYNC_MODE=1 in the
# environment (presumably forcing synchronous device-context execution —
# confirm against the runtime that reads this variable).
mojo_filecheck_test(
    name = "test_function_error_sync_mode.mojo.test",
    srcs = ["test_function_error_sync_mode.mojo"],
    env = {
        "MODULAR_DEVICE_CONTEXT_SYNC_MODE": "1",
    },
    expect_crash = True,
    tags = ["gpu"],
    target_compatible_with = ["//:nvidia_gpu"],
    deps = _GPU_TEST_DEPS,
)
